// Copyright (c) 2013, 2015, 2017 The ComPWA Team.
// This file is part of the ComPWA framework, check
// https://github.com/ComPWA/ComPWA/license.txt for details.

#include "MinLogLH.hpp"

#include <cmath>
#include <cstddef>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "Core/FunctionTree/FunctionTree.hpp"
#include "Core/FunctionTree/Intensity.hpp"
#include "Core/FunctionTree/ParameterList.hpp"
#include "Core/Kinematics.hpp"
#include "Core/Particle.hpp"
#include "Data/DataSet.hpp"

namespace ComPWA {
namespace Estimator {

using namespace ComPWA::FunctionTree;

/// Constructs a negative log-likelihood estimator.
///
/// \param intensity  intensity model evaluated on both samples; taken by
///                   value and moved into the member (sink parameter).
/// \param datapoints data sample entering the log-likelihood sum.
/// \param phsppoints phase-space sample used for the normalization; if it
///                   has no weights, evaluate() skips the normalization term.
MinLogLH::MinLogLH(std::shared_ptr<ComPWA::Intensity> intensity,
                   const Data::DataSet &datapoints,
                   const Data::DataSet &phsppoints)
    : Intensity(std::move(intensity)), DataSample(datapoints),
      PhspDataSample(phsppoints) {

  LOG(INFO) << "MinLogLH::MinLogLH() |  Size of data sample = "
            << datapoints.Weights.size();
}

double MinLogLH::evaluate() {
  double lh(0.0);

  double Norm(0.0);
  if (0 < PhspDataSample.Weights.size()) {
    double PhspIntegral(0.0);
    double WeightSum(0.0);
    auto Intensities = Intensity->evaluate(PhspDataSample.Data);
    auto IntensIter = Intensities.begin();
    for (auto x = PhspDataSample.Weights.begin();
         x != PhspDataSample.Weights.end(); ++x) {
      PhspIntegral += *x * *IntensIter;
      WeightSum += *x;
      ++IntensIter;
    }
    Norm = (std::log(PhspIntegral / WeightSum) * DataSample.Weights.size());
  }
  // calulate data log sum
  double LogSum(0.0);
  auto Intensities = Intensity->evaluate(DataSample.Data);
  for (size_t i = 0; i < DataSample.Weights.size(); ++i) {
    LogSum += std::log(Intensities[i]) * DataSample.Weights[i];
  }
  lh = Norm - LogSum;

  return lh;
}

std::shared_ptr<ComPWA::FunctionTree::FunctionTree>
createMinLogLHEstimatorFunctionTree(std::shared_ptr<OldIntensity> Intensity,
                                    ParameterList DataSampleList,
                                    ParameterList PhspDataSampleList) {
  LOG(DEBUG)
      << "createMinLogLHEstimatorFunctionTree(): constructing FunctionTree!";

  if (0 == DataSampleList.mDoubleValues().size()) {
    LOG(ERROR) << "createMinLogLHEstimatorFunctionTree(): Data sample is "
                  "empty! Please supply some data.";
    return std::shared_ptr<ComPWA::FunctionTree::FunctionTree>(nullptr);
  }
  size_t SampleSize = DataSampleList.mDoubleValue(0)->values().size();

  std::shared_ptr<Value<std::vector<double>>> weights;
  try {
    weights = findMDoubleValue("Weight", DataSampleList);
  } catch (const Exception &e) {
  }

  std::shared_ptr<ComPWA::FunctionTree::FunctionTree> EvaluationTree =
      std::make_shared<ComPWA::FunctionTree::FunctionTree>(
          "LH", std::make_shared<Value<double>>(),
          std::make_shared<AddAll>(ParType::DOUBLE));

  auto dataTree = std::make_shared<ComPWA::FunctionTree::FunctionTree>(
      "DataEvaluation", std::make_shared<Value<double>>(),
      std::make_shared<MultAll>(ParType::DOUBLE));
  dataTree->createLeaf("minusOne", -1, "DataEvaluation");
  dataTree->createNode("Sum",
                       std::shared_ptr<Strategy>(new AddAll(ParType::DOUBLE)),
                       "DataEvaluation");
  dataTree->createNode("WeightedLogIntensities",
                       std::shared_ptr<Strategy>(new MultAll(ParType::MDOUBLE)),
                       "Sum");
  if (weights)
    dataTree->createLeaf("EventWeight", weights, "WeightedLogIntensities");
  dataTree->createNode("Log",
                       std::shared_ptr<Strategy>(new LogOf(ParType::MDOUBLE)),
                       "WeightedLogIntensities");
  dataTree->insertTree(Intensity->createFunctionTree(DataSampleList, ""),
                       "Log");

  EvaluationTree->insertTree(dataTree, "LH");

  // if there is a phasespace sample then do the normalization
  if (0 < PhspDataSampleList.mDoubleValues().size()) {
    double PhspWeightSum(PhspDataSampleList.mDoubleValue(0)->values().size());

    std::shared_ptr<Value<std::vector<double>>> phspweights;
    try {
      phspweights = findMDoubleValue("Weight", PhspDataSampleList);
      PhspWeightSum = std::accumulate(phspweights->values().begin(),
                                      phspweights->values().end(), 0.0);
    } catch (const Exception &e) {
    }

    auto normTree = std::make_shared<ComPWA::FunctionTree::FunctionTree>(
        "Normalization(intensity)", std::make_shared<Value<double>>(),
        std::make_shared<MultAll>(ParType::DOUBLE));
    normTree->createLeaf("N", SampleSize, "Normalization(intensity)");
    normTree->createNode("Log",
                         std::shared_ptr<Strategy>(new LogOf(ParType::DOUBLE)),
                         "Normalization(intensity)");
    normTree->createNode(
        "Integral", std::shared_ptr<Strategy>(new MultAll(ParType::DOUBLE)),
        "Log");
    // normTree->createLeaf("PhspVolume", PhspVolume, "Integral");
    normTree->createLeaf("InverseSampleWeights", 1.0 / PhspWeightSum,
                         "Integral");
    normTree->createNode("Sum",
                         std::shared_ptr<Strategy>(new AddAll(ParType::DOUBLE)),
                         "Integral");
    normTree->createNode(
        "WeightedIntensities",
        std::shared_ptr<Strategy>(new MultAll(ParType::MDOUBLE)), "Sum");
    if (phspweights)
      normTree->createLeaf("EventWeight", phspweights, "WeightedIntensities");
    normTree->insertTree(
        Intensity->createFunctionTree(PhspDataSampleList, "phsp"),
        "WeightedIntensities");

    EvaluationTree->insertTree(normTree, "LH");
  } else {
    LOG(INFO) << "createMinLogLHEstimatorFunctionTree(): phsp sample is empty! "
                 "Skipping normalization and assuming intensity is normalized!";
  }
  LOG(DEBUG) << "createMinLogLHEstimatorFunctionTree(): construction of LH "
                "tree finished! Performing checks ...";
  EvaluationTree->parameter();
  if (!EvaluationTree->sanityCheck()) {
    throw std::runtime_error(
        "createMinLogLHEstimatorFunctionTree(): tree has structural "
        "problems. Sanity check not passed!");
  }
  LOG(DEBUG) << "createMinLogLHEstimatorFunctionTree(): finished!";
  return EvaluationTree;
}

/// Builds a FunctionTree-based negative log-likelihood estimator and the
/// list of fit parameters taken from the intensity's user parameters.
///
/// \param Intensity      wrapped intensity; supplies the old-style intensity
///                       used to build the tree plus its parameter lists.
/// \param DataSample     data events entering the log-likelihood sum.
/// \param PhspDataSample phase-space events used for the normalization.
/// \return the estimator wrapper and one FitParameter per double user
///         parameter (value, name, bounds if set, error if set, fixed flag).
std::tuple<FunctionTreeEstimatorWrapper, FitParameterList>
createMinLogLHFunctionTreeEstimator(
    std::shared_ptr<FunctionTreeIntensityWrapper> Intensity,
    const ComPWA::Data::DataSet &DataSample,
    const ComPWA::Data::DataSet &PhspDataSample) {

  ParameterList DataSampleList(DataSample);
  ParameterList PhspDataSampleList(PhspDataSample);

  auto oldintens = Intensity->getOldIntensity();

  auto ft = createMinLogLHEstimatorFunctionTree(
      std::move(oldintens), DataSampleList, PhspDataSampleList);

  // Bind the user parameter list to a named variable before iterating: if
  // getUserParameterList() returns by value, ranging directly over
  // `temporary.doubleParameters()` would iterate a destroyed object.
  auto &&UserParameters = Intensity->getUserParameterList();

  FitParameterList Pars;
  // Iterate by reference to avoid an atomic refcount inc/dec per parameter.
  for (const auto &x : UserParameters.doubleParameters()) {
    ComPWA::FitParameter<double> p;
    p.Value = x->value();
    p.Name = x->name();
    p.HasBounds = x->hasBounds();
    if (p.HasBounds) {
      p.Bounds = x->bounds();
    }
    if (x->hasError()) {
      p.Error = x->error();
    }
    p.IsFixed = x->isFixed();
    Pars.push_back(std::move(p));
  }

  return std::make_tuple(
      FunctionTreeEstimatorWrapper(ft, Intensity->getParameterList(),
                                   Intensity->getUserParameterList()),
      Pars);
}

/// Convenience overload for fits without a phase-space sample.
///
/// Forwards to the three-argument overload with a default-constructed
/// (empty) DataSet as phase-space sample.
/// NOTE(review): presumably this leads to the normalization term being
/// skipped (intensity assumed normalized) — confirm how ParameterList is
/// built from an empty DataSet.
std::tuple<FunctionTreeEstimatorWrapper, FitParameterList>
createMinLogLHFunctionTreeEstimator(
    std::shared_ptr<FunctionTreeIntensityWrapper> Intensity,
    const ComPWA::Data::DataSet &DataSample) {
  const ComPWA::Data::DataSet NoPhspSample{};
  return createMinLogLHFunctionTreeEstimator(Intensity, DataSample,
                                             NoPhspSample);
}

} // namespace Estimator
} // namespace ComPWA
